"""Load/unload cutlery to/from dishwasher."""
from abc import ABC

import numpy as np
from pyquaternion import Quaternion

from bigym.bigym_env import BiGymEnv, MAX_DISTANCE_FROM_TARGET
from bigym.const import HandSide
from bigym.envs.props.cutlery import Fork, Knife
from bigym.envs.props.dishwasher import Dishwasher
from bigym.envs.props.holders import CutleryTray
from bigym.envs.props.cabintets import BaseCabinet, BaseCabinetForCutlery
from bigym.envs.props.tableware import Mug
from bigym.utils.env_utils import get_random_sites
from bigym.utils.physics_utils import distance


# Placement (position, Euler angles in radians) of the two base cabinets that
# ``_DishwasherCutleryEnv._initialize_env`` creates and positions.
TABLE_1_POS = np.array([1, 0, 0])
TABLE_1_ROT = np.array([0, 0, -np.pi / 2])
TABLE_2_POS = np.array([1, -0.6, 0])
TABLE_2_ROT = np.array([0, 0, -np.pi / 2])

# Placement of the dishwasher; note it uses the same values as TABLE_1.
DISHWASHER_POS = np.array([1, 0, 0])
DISHWASHER_ROT = np.array([0, 0, -np.pi / 2])


class _DishwasherCutleryEnv(BiGymEnv, ABC):
    """Shared scene for the dishwasher cutlery tasks.

    Builds two base cabinets, a dishwasher and one prop for every class listed
    in ``_CUTLERY``. The episode fails once the robot pelvis is farther than
    ``MAX_DISTANCE_FROM_TARGET`` from the dishwasher, or once any cutlery item
    touches the floor.
    """

    _DEFAULT_ROBOT_POS = np.array([0, -0.6, 1])
    _USE_STABLE_GRIPPER = True

    # Prop classes instantiated, in this order, as ``self.cutlery``.
    _CUTLERY = [Knife, Fork]

    def _initialize_env(self):
        """Create the scene props, then place the static furniture."""
        # All props are created first (in a fixed order) and placed afterwards.
        self.cabinet_1: BaseCabinet = BaseCabinet(self._mojo, walls_enable=False)
        self.cabinet_2: BaseCabinet = BaseCabinet(self._mojo, panel_enable=True)
        self.dishwasher: Dishwasher = Dishwasher(self._mojo)
        self.cutlery = [cutlery_cls(self._mojo) for cutlery_cls in self._CUTLERY]

        placements = (
            (self.cabinet_1, TABLE_1_POS, TABLE_1_ROT),
            (self.cabinet_2, TABLE_2_POS, TABLE_2_ROT),
            (self.dishwasher, DISHWASHER_POS, DISHWASHER_ROT),
        )
        for prop, position, euler in placements:
            prop.body.set_position(position)
            prop.body.set_euler(euler)

    def _fail(self) -> bool:
        """Return True if the robot strayed too far or cutlery hit the floor."""
        too_far = (
            distance(self._robot.pelvis, self.dishwasher.body)
            > MAX_DISTANCE_FROM_TARGET
        )
        if too_far:
            return True
        return any(item.is_colliding(self.floor) for item in self.cutlery)

    def _on_reset(self):
        """Set the dishwasher to ``door=1, bottom_tray=1, middle_tray=0``.

        NOTE(review): presumably 1 means open and 0 means closed; confirm
        against ``Dishwasher.set_state``.
        """
        self.dishwasher.set_state(door=1, bottom_tray=1, middle_tray=0)


class _DishwasherUnloadCutleryEnv(_DishwasherCutleryEnv):
    """Common reset logic for tasks whose cutlery starts inside the dishwasher.

    On reset, each cutlery item is moved to a random site of the basket's first
    site set and raised by ``_CUTLERY_OFFSET_POS``. Its orientation is
    ``_CUTLERY_SPAWN_ROT`` right-multiplied by a random rotation about the
    y axis, drawn uniformly from ``[-_CUTLERY_BOUNDS_ANGLE, _CUTLERY_BOUNDS_ANGLE]``.
    """

    # Passed to ``get_random_sites`` as its ``segment`` argument.
    _SITES_SLICE = -2

    _CUTLERY_OFFSET_POS = np.array([0, 0, 0.1])
    _CUTLERY_BOUNDS_ANGLE = np.deg2rad(30)
    _CUTLERY_SPAWN_ROT = Quaternion(axis=[1, 0, 0], degrees=90)

    def _get_task_privileged_obs_space(self):
        """No task-specific privileged observations."""
        return {}

    def _get_task_privileged_obs(self):
        """No task-specific privileged observations."""
        return {}

    def _on_reset(self):
        """Reset the dishwasher, then scatter the cutlery over basket sites."""
        super()._on_reset()
        basket_sites = self.dishwasher.basket.site_sets[0]
        chosen_sites = get_random_sites(
            basket_sites, len(self.cutlery), segment=self._SITES_SLICE
        )
        max_tilt = self._CUTLERY_BOUNDS_ANGLE
        for site, item in zip(chosen_sites, self.cutlery):
            spawn_position = site.get_position() + self._CUTLERY_OFFSET_POS
            tilt_angle = np.random.uniform(-max_tilt, max_tilt)
            tilt = Quaternion(axis=[0, 1, 0], angle=tilt_angle)
            spawn_rotation = self._CUTLERY_SPAWN_ROT * tilt
            # Orientation is applied before position, as in the original.
            item.body.set_quaternion(spawn_rotation.elements, True)
            item.body.set_position(spawn_position, True)


class DishwasherUnloadCutlery(_DishwasherUnloadCutleryEnv):
    """Unload cutlery from the dishwasher into a cutlery tray.

    The episode succeeds when every cutlery item touches the tray and neither
    gripper is holding it.
    """

    _TRAY_POS = np.array([0.65, -0.6, 0.86])
    # Half-width of the uniform position jitter applied to the tray on reset.
    _TRAY_BOUNDS = np.array([0.05, 0.05, 0])
    _TRAY_ROT = np.array([0, 0, -np.pi / 2])

    def _initialize_env(self):
        """Build the shared scene plus the target cutlery tray."""
        super()._initialize_env()
        self.tray = CutleryTray(self._mojo)

    def _success(self) -> bool:
        """Every item is on the tray and has been released by both hands."""
        return all(
            item.is_colliding(self.tray)
            and not any(
                self.robot.is_gripper_holding_object(item, side)
                for side in HandSide
            )
            for item in self.cutlery
        )

    def _on_reset(self):
        """Reset the dishwasher/cutlery, then place the tray with random jitter."""
        super()._on_reset()
        jitter = np.random.uniform(-self._TRAY_BOUNDS, self._TRAY_BOUNDS)
        tray_position = self._TRAY_POS + jitter
        self.tray.body.set_position(tray_position)
        self.tray.body.set_euler(self._TRAY_ROT)


class DishwasherUnloadCutleryLong(_DishwasherUnloadCutleryEnv):
    """Unload cutlery from the dishwasher to a drawer task.

    Success requires all of the following:

    * the dishwasher state and the cutlery cabinet state are both within
      ``_TOLERANCE`` of zero (presumably "fully closed"; the base reset sets
      the dishwasher to ``door=1, bottom_tray=1``; confirm against
      ``get_state``);
    * every cutlery item touches the cutlery cabinet's tray;
    * no cutlery item is held by either gripper.
    """

    _CUTLERY = [Fork]

    _CABINET_POS = np.array([1, -1.2, 0])
    _CABINET_ROT = np.array([0, 0, -np.pi / 2])

    # Absolute tolerance used when comparing articulated states to zero.
    _TOLERANCE = 0.1

    def _initialize_env(self):
        """Build the shared scene plus the cutlery cabinet (drawer)."""
        super()._initialize_env()
        # Use ``self._mojo``, like every other prop constructed in this module.
        self.cutlery_cabinet = BaseCabinetForCutlery(self._mojo)
        self.cutlery_cabinet.body.set_position(self._CABINET_POS)
        self.cutlery_cabinet.body.set_euler(self._CABINET_ROT)

    def _success(self) -> bool:
        """Return True once everything is closed and all cutlery is stowed."""
        if not np.allclose(self.dishwasher.get_state(), 0, atol=self._TOLERANCE):
            return False
        if not np.allclose(self.cutlery_cabinet.get_state(), 0, atol=self._TOLERANCE):
            return False
        for item in self.cutlery:
            if not item.is_colliding(self.cutlery_cabinet.tray):
                return False
            for side in HandSide:
                if self.robot.is_gripper_holding_object(item, side):
                    return False
        return True


class DishwasherLoadCutlery(_DishwasherCutleryEnv):
    """Load cutlery into the dishwasher basket task.

    On reset, a mug gets a random rotation about z in
    ``[-_MUG_BOUNDS_ANGLE, _MUG_BOUNDS_ANGLE]`` and a uniform position jitter
    of up to ``_MUG_BOUNDS`` around ``_MUG_POS``. Each cutlery item is then
    spawned ``_BASKET_OFFSET_POS`` above the mug. Items are spread radially by
    ``_CUTLERY_SPAWN_OFFSET``, one ``_CUTLERY_OFFSET_ANGLE`` step per item,
    with each step jittered by ``_CUTLERY_OFFSET_ANGLE_RANGE``. Presumably
    this makes the cutlery settle inside the mug.

    The episode succeeds when every item touches the dishwasher basket
    colliders and neither gripper is holding it.
    """

    _MUG_POS = np.array([0.65, -0.6, 0.86])
    _MUG_BOUNDS = np.array([0.05, 0.05, 0])
    _MUG_BOUNDS_ANGLE = np.deg2rad(90)

    _BASKET_OFFSET_POS = np.array([0, 0, 0.15])
    _CUTLERY_SPAWN_ROT = Quaternion(axis=[1, 0, 0], degrees=90)
    _CUTLERY_OFFSET_ANGLE = np.deg2rad(90)
    _CUTLERY_OFFSET_ANGLE_RANGE = np.deg2rad(5)
    _CUTLERY_SPAWN_OFFSET = 0.02

    def _initialize_env(self):
        """Build the shared scene plus the mug holding the cutlery."""
        super()._initialize_env()
        self.mug = Mug(self._mojo, False)

    def _get_task_privileged_obs_space(self):
        """No task-specific privileged observations."""
        return {}

    def _get_task_privileged_obs(self):
        """No task-specific privileged observations."""
        return {}

    def _success(self) -> bool:
        """Every item is in the basket and has been released by both hands."""
        return all(
            item.is_colliding(self.dishwasher.basket.colliders)
            and not any(
                self.robot.is_gripper_holding_object(item, side)
                for side in HandSide
            )
            for item in self.cutlery
        )

    def _on_reset(self):
        """Reset the dishwasher, then randomize the mug and the cutlery in it."""
        super()._on_reset()
        # Random draws happen in a fixed order: mug angle, mug jitter, then
        # one angular jitter per cutlery item.
        mug_z_angle = np.random.uniform(
            -self._MUG_BOUNDS_ANGLE, self._MUG_BOUNDS_ANGLE
        )
        self.mug.body.set_euler(np.array([0, 0, mug_z_angle]))
        mug_jitter = np.random.uniform(-self._MUG_BOUNDS, self._MUG_BOUNDS)
        mug_pos = self._MUG_POS + mug_jitter
        self.mug.body.set_position(mug_pos, True)

        spread = self._CUTLERY_OFFSET_ANGLE_RANGE
        for index, item in enumerate(self.cutlery):
            item.body.set_quaternion(self._CUTLERY_SPAWN_ROT.elements, True)
            heading = self._CUTLERY_OFFSET_ANGLE * index + np.random.uniform(
                -spread, spread
            )
            radial = self._CUTLERY_SPAWN_OFFSET * np.array(
                [np.cos(heading), np.sin(heading), 0]
            )
            spawn_position = mug_pos + self._BASKET_OFFSET_POS + radial
            item.body.set_position(spawn_position, True)
